import math
import warnings
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
from linearmodels.iv import IV2SLS


# Generic non-response codes (presumably the ESS refusal / don't know / no
# answer codes and their 2-, 3- and 4-digit variants -- verify per item).
# Only safe for items whose valid range cannot contain these values (short
# yes/no or ordinal scales); items such as age, years of schooling, income
# deciles or minutes need variable-specific codes instead.
MISSING_CODES = {
    7, 8, 9,
    77, 88, 99,
    777, 888, 999,
    7777, 8888, 9999,
}


def clean_numeric(series: pd.Series, missing_codes=None) -> pd.Series:
    """
    Coerce ``series`` to numeric and replace missing-value codes with NaN.

    Unparsable entries become NaN (``pd.to_numeric(errors="coerce")``).

    Parameters
    ----------
    series : values to clean.
    missing_codes : optional iterable of codes to treat as missing. ``None``
        (the default) uses the module-level ``MISSING_CODES``. Pass an explicit
        set for items whose valid range overlaps those generic codes (e.g. an
        income decile of 7 or 8 years of schooling), or ``()`` to mask nothing.

    Returns
    -------
    The numeric series with the chosen codes set to NaN.
    """
    codes = MISSING_CODES if missing_codes is None else frozenset(missing_codes)
    s = pd.to_numeric(series, errors="coerce")
    return s.mask(s.isin(codes))


def recode_yes_no(series: pd.Series) -> pd.Series:
    """
    Turn a 1/2-coded yes/no item into a float indicator.

    The series is first passed through ``clean_numeric`` (numeric coercion
    plus missing-code masking). Then 1 becomes 1.0, 2 becomes 0.0, and every
    other value, including any remaining code, becomes NaN.
    """
    yes_no_lookup = {1: 1.0, 2: 0.0}
    cleaned = clean_numeric(series)
    return cleaned.map(yes_no_lookup)


def zscore(series: pd.Series) -> pd.Series:
    """
    Standardize ``series`` to mean 0 and population (ddof=0) SD 1.

    Non-numeric entries are coerced to NaN and stay NaN. When the SD is
    undefined or not positive (e.g. a constant or all-missing series), the
    result is ``series * 0``: zeros where a value exists, NaN elsewhere.
    """
    values = pd.to_numeric(series, errors="coerce")
    spread = values.std(ddof=0)
    if pd.isna(spread) or not spread > 0:
        return values * 0
    centre = values.mean()
    return (values - centre) / spread


def absorb_two_way(df: pd.DataFrame, cols: list[str], fe_a: str, fe_b: str, max_iter: int = 30, tol: float = 1e-10) -> pd.DataFrame:
    """
    Absorb two sets of fixed effects via alternating projections (double-demeaning).

    Each sweep subtracts the ``fe_a`` group means, then the ``fe_b`` group
    means, from every column in ``cols``. Sweeps repeat until the largest
    absolute change within one sweep is below ``tol``, or until ``max_iter``
    sweeps have run. This avoids constructing huge dummy matrices (region FE
    is large).

    Parameters
    ----------
    df : frame holding ``cols``, ``fe_a`` and ``fe_b``.
    cols : numeric columns to residualize. They should contain no NaN: a NaN
        contaminates the mean of every group it belongs to.
    fe_a, fe_b : names of the two grouping columns.
    max_iter : maximum number of A-then-B sweeps.
    tol : convergence threshold on the max absolute change per sweep.

    Returns
    -------
    DataFrame of residualized ``cols``, indexed like ``df``.

    Raises
    ------
    ValueError
        If ``df`` has no rows or either FE column contains missing values.

    Warns
    -----
    RuntimeWarning
        If convergence is not reached within ``max_iter`` sweeps.
    """
    if df.empty:
        raise ValueError("absorb_two_way: input DataFrame has no rows")

    def _group_index(fe: str) -> tuple[np.ndarray, int, np.ndarray]:
        # factorize keeps only observed levels, so every group count is >= 1.
        # Unused categorical levels would otherwise produce 0/0 means.
        codes, uniques = pd.factorize(df[fe])
        if (codes < 0).any():
            raise ValueError(f"absorb_two_way: fixed-effect column {fe!r} contains missing values")
        n_groups = len(uniques)
        counts = np.bincount(codes, minlength=n_groups).astype(float)
        return codes, n_groups, counts

    a_codes, a_n, a_counts = _group_index(fe_a)
    b_codes, b_n, b_counts = _group_index(fe_b)

    x = df[cols].to_numpy(dtype=float, copy=True)

    def _subtract_group_means(codes: np.ndarray, n_groups: int, counts: np.ndarray) -> None:
        # In place: remove each column's within-group mean.
        for j in range(x.shape[1]):
            means = np.bincount(codes, weights=x[:, j], minlength=n_groups) / counts
            x[:, j] -= means[codes]

    delta = float("nan")
    for _ in range(max_iter):
        prev = x.copy()
        _subtract_group_means(a_codes, a_n, a_counts)
        _subtract_group_means(b_codes, b_n, b_counts)
        # With no columns there is nothing to demean (and nanmax would raise).
        delta = float(np.nanmax(np.abs(x - prev))) if x.size else 0.0
        if delta < tol:
            break
    else:
        # Unconverged residuals still carry FE variation; make that visible.
        warnings.warn(
            f"absorb_two_way did not converge after {max_iter} iterations "
            f"(last max change {delta:.3g}, tol {tol:g}); residuals may retain fixed-effect variation",
            RuntimeWarning,
            stacklevel=2,
        )

    return pd.DataFrame(x, columns=cols, index=df.index)


@dataclass(frozen=True)
class OutcomeSpec:
    """
    One outcome to estimate and report.

    Raises ValueError on construction if ``kind`` is not one of the supported
    outcome kinds.
    """

    name: str  # column name of the outcome in the analysis DataFrame
    label: str  # human-readable label used in output tables
    kind: str  # "binary" | "continuous"

    def __post_init__(self) -> None:
        # The field is documented as a closed set; reject typos early.
        if self.kind not in ("binary", "continuous"):
            raise ValueError(f"OutcomeSpec.kind must be 'binary' or 'continuous', got {self.kind!r}")


# Item-specific non-response codes for variables whose valid range overlaps the
# generic MISSING_CODES: income deciles 7-9, 7-9 years of schooling, ages
# 77/88/99, and any number of minutes of news. Taken from the ESS codebooks.
# NOTE(review): confirm against the codebook of every round in the input file.
_ESS_ITEM_MISSING_CODES: dict[str, frozenset[int]] = {
    "agea": frozenset({999}),
    "eduyrs": frozenset({77, 88, 99}),
    "hinctnta": frozenset({77, 88, 99}),
    "nwspol": frozenset({7777, 8888, 9999}),
}


def _to_numeric_masking(series: pd.Series, codes: frozenset[int]) -> pd.Series:
    """Coerce ``series`` to numeric (unparsable -> NaN) and set ``codes`` to NaN."""
    s = pd.to_numeric(series, errors="coerce")
    return s.mask(s.isin(codes))


def build_dataset(root: Path) -> pd.DataFrame:
    """
    Load the ESS extract and the Bartik instrument, then build the estimation sample.

    Inputs (relative to ``root``):
      - outputs/ess_r8_r11_min.parquet
      - data_external/bartik_region_year.csv  (region, survey_year, z_bartik_density)

    Steps:
      1. Keep respondents with ``regunit`` in {1, 2} and a region other than
         "99999". Normalize region codes (upper-case, stripped) on both sides.
      2. Left-merge the instrument on (region, survey_year). Instrument rows
         without a usable year are dropped. Each key must appear at most once
         in the instrument file; duplicates raise ``pandas.errors.MergeError``
         (a ValueError) rather than silently duplicating respondents.
      3. Recode the endogenous regressor, the outcomes, and the mechanism,
         placebo and control variables. Continuous items use item-specific
         missing codes so that valid values such as 7-9 are kept.
      4. Keep rows with a non-missing instrument, ``edu_high`` and
         ``netusoft``, then standardize the instrument within that sample.

    Returns the analysis DataFrame.
    """
    ess = pd.read_parquet(root / "outputs" / "ess_r8_r11_min.parquet")
    bartik = pd.read_csv(root / "data_external" / "bartik_region_year.csv")

    ess = ess[ess["regunit"].isin([1, 2])].copy()
    ess["region"] = ess["region"].astype(str).str.upper().str.strip()
    ess = ess[ess["region"] != "99999"].copy()

    bartik["region"] = bartik["region"].astype(str).str.upper().str.strip()
    bartik["survey_year"] = pd.to_numeric(bartik["survey_year"], errors="coerce").astype("Int64")
    bartik["z_bartik_density"] = pd.to_numeric(bartik["z_bartik_density"], errors="coerce")
    # pandas merges NA keys with NA keys, so unusable years must not take part.
    bartik = bartik.dropna(subset=["survey_year"])

    df = ess.merge(
        bartik[["region", "survey_year", "z_bartik_density"]],
        on=["region", "survey_year"],
        how="left",
        validate="many_to_one",
    )

    # Endogenous regressor (short ordinal scale: generic codes are safe)
    df["netusoft"] = clean_numeric(df["netusoft"])

    # Participation outcomes: 1 -> 1.0, 2 -> 0.0, anything else -> NaN
    for item in ("vote", "sgnptit", "pbldmn", "bctprd", "contplt", "badge"):
        df[f"y_{item}"] = recode_yes_no(df[item])
    df["y_pstplonl"] = recode_yes_no(df["pstplonl"]) if "pstplonl" in df.columns else np.nan
    part_cols = ["y_sgnptit", "y_pbldmn", "y_bctprd", "y_contplt", "y_badge", "y_pstplonl"]
    # Mean over the available items; a row with every item missing yields NaN.
    df["y_part_index"] = df[part_cols].mean(axis=1, skipna=True)

    # Mechanism candidate: minutes, whose valid values include 7, 8, 9, 77, ...
    df["nwspol"] = _to_numeric_masking(df["nwspol"], _ESS_ITEM_MISSING_CODES["nwspol"])
    df["log_nwspol"] = np.log1p(df["nwspol"].clip(lower=0))

    # Placebos (predetermined-ish)
    df["gndr_num"] = clean_numeric(df["gndr"])  # 1/2
    df["gndr"] = df["gndr_num"]
    df["brncntr_bin"] = recode_yes_no(df["brncntr"]) if "brncntr" in df.columns else np.nan

    # Controls: continuous items, so use item-specific missing codes
    for col in ("agea", "hinctnta", "eduyrs"):
        df[col] = _to_numeric_masking(df[col], _ESS_ITEM_MISSING_CODES[col])

    # Education indicator: high (> 12 years of schooling) vs low; NaN if unknown
    edu = df["eduyrs"]
    df["edu_high"] = np.where(edu.isna(), np.nan, (edu > 12).astype(float))

    df["ct"] = df["cntry"].astype(str) + "_" + df["survey_year"].astype("Int64").astype(str)

    # Estimation pool: instrument, education group and endogenous regressor observed
    in_pool = df["z_bartik_density"].notna() & df["edu_high"].notna() & df["netusoft"].notna()
    df = df[in_pool].copy()

    # Standardize within estimation sample
    df["z_bartik_density"] = zscore(df["z_bartik_density"])

    return df


def fit_iv_interaction_absorb(df: pd.DataFrame, y: str, controls: list[str]) -> IV2SLS:
    """
    Estimate the education-interacted Bartik-IV model with two-way FE absorbed.

    Region FE and country×survey_year (``ct``) FE are partialled out first:
      y = b1*netusoft + b2*(netusoft*edu_high) + controls + regionFE + ctFE + e
    Instruments (Bartik):
      z_bartik_density and z_bartik_density*edu_high

    Rows missing any of the outcome, regressors, instrument, FE identifiers
    or controls are dropped. Standard errors are clustered by region.

    Returns the result of ``IV2SLS(...).fit(...)``: the fitted result, despite
    the ``IV2SLS`` annotation.

    Raises ValueError if no complete rows remain.
    """
    needed = [y, "netusoft", "edu_high", "z_bartik_density", "ct", "region", *controls]
    sample = df[needed].dropna().copy()
    if sample.empty:
        raise ValueError(f"No rows after dropna for outcome={y}")

    # Interactions are built before absorption so they are residualized
    # exactly like the main effects.
    sample = sample.assign(
        netusoft_x_edu=sample["netusoft"] * sample["edu_high"],
        z_bartik_x_edu=sample["z_bartik_density"] * sample["edu_high"],
    )

    to_absorb = [y, "netusoft", "netusoft_x_edu", "z_bartik_density", "z_bartik_x_edu", *controls, "edu_high"]
    resid = absorb_two_way(sample, to_absorb, fe_a="region", fe_b="ct")

    model = IV2SLS(
        resid[y],
        exog=resid[[*controls, "edu_high"]],
        endog=resid[["netusoft", "netusoft_x_edu"]],
        instruments=resid[["z_bartik_density", "z_bartik_x_edu"]],
    )
    return model.fit(cov_type="clustered", clusters=sample["region"])


def summarize_result(res, y: str, label: str) -> dict:
    """
    Flatten a fitted interaction-IV result into one table row.

    Parameters
    ----------
    res : fitted result (as returned by ``fit_iv_interaction_absorb``). It
        must expose ``params``, ``std_errors``, ``nobs``, ``cov`` (parameter
        covariance, indexed by parameter name) and ``first_stage.diagnostics``
        (rows per endogenous variable, with an ``"f.stat"`` column).
    y : outcome column name.
    label : human-readable outcome label.

    Returns
    -------
    dict with:
      - the low-education effect b1 (coefficient on ``netusoft``);
      - the high-education effect b1 + b2, whose SE is
        sqrt(Var b1 + Var b2 + 2 Cov(b1, b2)) taken from ``res.cov``;
      - the difference b2 (coefficient on ``netusoft_x_edu``) with its SE;
      - first-stage F statistics.
    Entries that cannot be computed are NaN. If an endogenous name is absent
    from the diagnostics index, the first/last diagnostics row is used.
    """
    endo, inter = "netusoft", "netusoft_x_edu"
    b1 = float(res.params.get(endo, np.nan))
    b2 = float(res.params.get(inter, np.nan))
    se1 = float(res.std_errors.get(endo, np.nan))
    se2 = float(res.std_errors.get(inter, np.nan))

    # The SE of b1 + b2 needs the covariance term; se1 and se2 alone are not enough.
    se_high = np.nan
    cov = getattr(res, "cov", None)
    pair = {endo, inter}
    if cov is not None and pair <= set(cov.index) and pair <= set(cov.columns):
        var_high = float(cov.loc[endo, endo] + cov.loc[inter, inter] + 2.0 * cov.loc[endo, inter])
        if var_high >= 0:
            se_high = float(np.sqrt(var_high))

    diag = res.first_stage.diagnostics
    fs_net = diag.loc[endo] if endo in diag.index else diag.iloc[0]
    fs_int = diag.loc[inter] if inter in diag.index else diag.iloc[-1]

    return {
        "outcome": y,
        "label": label,
        "nobs": int(res.nobs),
        "coef_lowEdu": b1,
        "se_lowEdu": se1,
        "coef_highEdu": b1 + b2,
        "se_highEdu": se_high,
        "coef_deltaHighMinusLow": b2,
        "se_deltaHighMinusLow": se2,
        "fsF_netusoft": float(fs_net.get("f.stat", np.nan)),
        "fsF_netusoft_x_edu": float(fs_int.get("f.stat", np.nan)),
    }


def to_markdown_table(df: pd.DataFrame) -> str:
    """
    Render result rows (shaped like ``summarize_result`` output) as a pipe-style markdown table.

    Coefficients and standard errors get 4 decimals and first-stage F
    statistics get 2; NaN cells are left blank. An ``se_highEdu`` column is
    included when present.

    An empty ``df`` yields a header-only table instead of raising. Non-empty
    input requires the optional ``tabulate`` package (used by
    ``DataFrame.to_markdown``) and raises KeyError if a required column is
    missing.
    """
    cols = ["label", "nobs", "coef_lowEdu", "se_lowEdu", "coef_highEdu"]
    if "se_highEdu" in df.columns:
        cols.append("se_highEdu")
    cols += ["coef_deltaHighMinusLow", "se_deltaHighMinusLow", "fsF_netusoft", "fsF_netusoft_x_edu"]

    if df.empty:
        header = "| " + " | ".join(cols) + " |"
        rule = "|" + "|".join("-" * (len(c) + 2) for c in cols) + "|"
        return header + "\n" + rule

    d = df[cols].copy()
    for c in (c for c in cols if c.startswith(("coef_", "se_"))):
        d[c] = d[c].map(lambda x: f"{x:.4f}" if pd.notna(x) else "")
    for c in ("fsF_netusoft", "fsF_netusoft_x_edu"):
        d[c] = d[c].map(lambda x: f"{x:.2f}" if pd.notna(x) else "")
    # By default tabulate re-parses numeric-looking strings and re-formats them
    # with "g", which strips the fixed decimals set above ("0.1000" -> "0.1").
    # Keep the cells verbatim, and right-align everything except the label.
    return d.to_markdown(
        index=False,
        disable_numparse=True,
        colalign=("left",) + ("right",) * (len(cols) - 1),
    )


def main() -> None:
    """
    Run the full Bartik-IV pipeline and write its outputs.

    The repository root is the parent of this script's directory. Outputs are
    written to ``<root>/outputs``:
      - iv_bartik_absorb_main_table.csv / .md : participation outcomes
      - iv_bartik_absorb_placebo_mechanism.md : mechanism and placebo outcomes
      - iv_bartik_absorb_checks.txt : short sample summary
    A mechanism/placebo outcome with no non-missing values is skipped and
    listed in the checks file, instead of aborting the run.
    """
    root = Path(__file__).resolve().parents[1]
    out_dir = root / "outputs"
    out_dir.mkdir(parents=True, exist_ok=True)

    df = build_dataset(root)

    outcomes = [
        OutcomeSpec("y_vote", "Vote (yes/no)", "binary"),
        OutcomeSpec("y_part_index", "Participation index (mean of non-vote items)", "continuous"),
        OutcomeSpec("y_sgnptit", "Signed petition", "binary"),
        OutcomeSpec("y_contplt", "Contacted politician/official", "binary"),
        OutcomeSpec("y_badge", "Worn campaign badge/sticker", "binary"),
        OutcomeSpec("y_pbldmn", "Taken part in public demonstration", "binary"),
        OutcomeSpec("y_bctprd", "Boycotted products", "binary"),
    ]

    controls = ["agea", "gndr", "hinctnta"]

    rows = []
    for spec in outcomes:
        res = fit_iv_interaction_absorb(df, spec.name, controls=controls)
        rows.append(summarize_result(res, spec.name, spec.label))

    tab = pd.DataFrame(rows)
    (out_dir / "iv_bartik_absorb_main_table.csv").write_text(tab.to_csv(index=False), encoding="utf-8")
    (out_dir / "iv_bartik_absorb_main_table.md").write_text(to_markdown_table(tab) + "\n", encoding="utf-8")

    # Mechanism + placebos
    mech_placebo = [
        ("log_nwspol", "mechanism: log(1+nwspol minutes)", ["agea", "gndr", "hinctnta"]),
        ("brncntr_bin", "placebo: born in country (1/0)", ["agea", "gndr", "hinctnta"]),
        ("gndr", "placebo: gender (numeric code)", ["agea", "hinctnta"]),
    ]
    mp_rows = []
    skipped = []
    for y, label, ctrls in mech_placebo:
        # build_dataset creates brncntr_bin as an all-NaN column when the source
        # item is absent, so checking existence alone is not enough: fitting an
        # all-missing outcome raises ValueError and would abort the whole run.
        if y not in df.columns or not df[y].notna().any():
            skipped.append(y)
            continue
        res = fit_iv_interaction_absorb(df, y, controls=ctrls)
        mp_rows.append(summarize_result(res, y, label))

    mp_path = out_dir / "iv_bartik_absorb_placebo_mechanism.md"
    if mp_rows:
        mp_path.write_text(to_markdown_table(pd.DataFrame(mp_rows)) + "\n", encoding="utf-8")
    else:
        # An empty DataFrame has none of the table columns; write a note instead.
        mp_path.write_text("No mechanism/placebo outcome had usable data.\n", encoding="utf-8")

    # Short text summary
    txt = [
        "Bartik-IV (baseline density × country-year broadband shock) with absorbed region FE + country×year FE",
        f"Rows in estimation pool: {len(df):,}",
        f"Countries: {df['cntry'].nunique()}  Regions: {df['region'].nunique()}  Years: {sorted(df['survey_year'].unique().tolist())}",
    ]
    if skipped:
        txt.append(f"Skipped mechanism/placebo outcomes (no non-missing values): {', '.join(skipped)}")
    txt.append("Outputs:")
    txt.append(f"- {out_dir / 'iv_bartik_absorb_main_table.md'}")
    txt.append(f"- {mp_path}")
    (out_dir / "iv_bartik_absorb_checks.txt").write_text("\n".join(txt) + "\n", encoding="utf-8")
    print(f"Wrote: {out_dir / 'iv_bartik_absorb_main_table.md'}")


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()
